{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TASK 2 (2017-2019) - Main Results\n",
    "\n",
    "\n",
    "import sys\n",
    "import os\n",
    "import logging\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from datetime import datetime\n",
    "\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "from joblib import dump, load\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "import views\n",
    "from views import Ensemble, Model, Downsampling, Period\n",
    "from views.utils.data import assign_into_df\n",
    "from views.apps.transforms import lib as translib\n",
    "from views.apps.evaluation import lib as evallib, feature_importance as fi\n",
    "from views.apps.model import api\n",
    "from views.apps.extras import extras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = views.DATASETS[\"cm_africa_imp_0\"]\n",
    "df = dataset.df\n",
    "#print(df)\n",
    "level = \"cm\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = \"./models/eval_real/{sub}\"\n",
    "out_paths = {\n",
    "    \"evaluation2\": model_path.format(sub=\"evaluation2\"),\n",
    "    \"features2\": model_path.format(sub=\"features2\")\n",
    "}\n",
    "for k, v in out_paths.items():\n",
    "    if not os.path.isdir(v):\n",
    "        os.makedirs(v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import our data\n",
    "wiki = pd.read_csv(r'/home/default/Dokumente/TRINITY/comp/prediction-project/data/Wiki_Data_Final_2021.csv')\n",
    "multi = wiki.set_index(['month_id', 'country_id'])\n",
    "multi = multi.sort_index(level=0)\n",
    "\n",
    "df = df.merge(multi, on=['month_id', 'country_id'], how='left')\n",
    "df = df.fillna(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Partition Dataset in Training, Test and Calibration set\n",
    "period_calib = api.Period(\n",
    "   name=\"calib\", \n",
    "   train_start=121,   # 1990-01\n",
    "   train_end=408,     # 2013.12\n",
    "   predict_start=409, # 2014.01\n",
    "   predict_end=443,   # 2016.11\n",
    ")\n",
    "\n",
    "\n",
    "period_test = api.Period(\n",
    "   name=\"test\", \n",
    "   train_start=121,   # 1990-01\n",
    "   train_end=443,     # 2016.11\n",
    "   predict_start=445, # 2017.01\n",
    "   predict_end=480,   # 2019.12\n",
    ")\n",
    "\n",
    "\n",
    "periods = [period_calib, period_test]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The steps to train, predict and evaluate for.\n",
    "steps = [2,3,4,5,6,7]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Subset \"most important\" features from the ViEWS-dataset and Wikipedia variables\n",
    "\n",
    "cols_features = df[['ged_best_sb','wdi_vc_btl_deth', 'reign_precip','reign_prev_conflict','reign_tenure_months',\n",
    "           'reign_irregular','reign_lastelection','reign_loss','reign_pctile_risk','reign_couprisk',\n",
    "          'ged_count_sb','ged_best_os','ged_count_os','wdi_eg_use_pcap_kg_oe','ged_best_ns',\n",
    "           'vdem_v2x_accountability', 'wdi_nv_srv_totl_zs','wdi_sp_pop_totl','wdi_sl_tlf_totl_fe_zs',\n",
    "          'wdi_sm_pop_totl_zs', 'wdi_sm_pop_refg_or', 'wdi_dt_oda_odat_pc_zs', 'ged_count_ns', \n",
    "           'reign_gov_dominant_party', 'reign_age', 'vdem_v2xpe_exlpol', 'wdi_ag_lnd_frst_k2', \n",
    "           'wdi_bg_gsr_nfsv_gd_zs', 'wdi_st_int_rcpt_xp_zs', 'vdem_v2x_clpol', 'vdem_v2x_genpp', 'viewsEnglish']]\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Specify number of estimators in Random Forest\n",
    "n_estimators = 200"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train Model\n",
    "\n",
    "task2_delta = api.Model(\n",
    "    name=\"task2_delta\",                \n",
    "    col_outcome=\"ln_ged_best_sb\",    \n",
    "    cols_features=cols_features,     \n",
    "    steps=steps,                     \n",
    "    outcome_type=\"real\",             \n",
    "    periods=periods,                 \n",
    "    estimator=RandomForestRegressor( \n",
    "        n_estimators=n_estimators,\n",
    "        criterion=\"mse\",\n",
    "        n_jobs=-1,\n",
    "    ),\n",
    "    delta_outcome=True,            \n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "models = [task2_delta]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 14min 59s, sys: 14.9 s, total: 15min 14s\n",
      "Wall time: 20min 1s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Train all models\n",
    "for model in models:\n",
    "    model.fit_estimators(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store predictions and calibrated predictions for all models in our dataframe\n",
    "for model in models:\n",
    "    df_predictions = model.predict(df)\n",
    "    df = assign_into_df(df, df_predictions)\n",
    "    df_predictions = model.predict_calibrated(\n",
    "        df=df,\n",
    "        period_calib = period_calib,\n",
    "        period_test = period_test\n",
    "    )\n",
    "    df = assign_into_df(df, df_predictions)\n",
    "  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate all models. Scores are stored in the model object\n",
    "for model in models:\n",
    "    model.evaluate(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select test partition and compute performance\n",
    "\n",
    "partition = \"test\"\n",
    "\n",
    "for model in models:\n",
    "    for calib in [\"uncalibrated\"]:\n",
    "        scores = {\n",
    "            \"Step\":[], \n",
    "            \"MSE\":[]\n",
    "        }\n",
    "        if model.delta_outcome:\n",
    "            scores.update({\"TADDA\":[]}) \n",
    "            \n",
    "        for key, value in model.scores[partition].items():\n",
    "            if key != \"sc\":\n",
    "                scores[\"Step\"].append(key)\n",
    "                scores[\"MSE\"].append(value[calib][\"mse\"])\n",
    "                if model.delta_outcome:\n",
    "                    scores[\"TADDA\"].append(value[calib][\"tadda_score\"])\n",
    "\n",
    "        out = pd.DataFrame(scores)\n",
    "        tex = out.to_latex(index=False)\n",
    "\n",
    "        # Add meta.\n",
    "        now = datetime.now().strftime(\"%Y/%m/%d %H:%M:%S\")\n",
    "        meta = f\"\"\"\n",
    "        %Output created by wb_models.ipynb.\n",
    "        %Evaluation of {model.col_outcome} per step.\n",
    "        %Run on selected {model.name} features at {level} level.\n",
    "        %Produced on {now}, written to {out_paths[\"evaluation2\"]}.\n",
    "        \\\\\n",
    "        \"\"\"\n",
    "        tex = meta + tex\n",
    "        path_out = os.path.join(\n",
    "            out_paths[\"evaluation2\"], \n",
    "            f\"{model.name}_{level}_{calib}_scores.tex\"\n",
    "        )\n",
    "   #     with open(path_out, \"w\") as f:\n",
    "    #        f.write(tex)\n",
    "     #   print(f\"Wrote scores table to {path_out}.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   Step       MSE     TADDA\n",
      "0     2  0.708021  0.547771\n",
      "1     3  0.726208  0.544299\n",
      "2     4  0.702513  0.521719\n",
      "3     5  0.816718  0.582015\n",
      "4     6  0.793014  0.578380\n",
      "5     7  0.812194  0.572862\n"
     ]
    }
   ],
   "source": [
    "print(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute Permutation Feature Importance\n",
    "\n",
    "sort_step = 3\n",
    "top = 35\n",
    "\n",
    "for model in models:\n",
    "    for step in steps:\n",
    "        pi_dict = model.extras.permutation_importances[\"test\"][step][\"test\"]\n",
    "        step_df = pd.DataFrame(fi.reorder_fi_dict(pi_dict))\n",
    "        step_df = step_df.rename(columns={\"importance\": f\"s={step}\"})\n",
    "        step_df.set_index(\"feature\", inplace=True)\n",
    "        pi_df = pi_df.join(step_df) if step > steps[0] else step_df.copy()\n",
    "    \n",
    "    pi_df = pi_df.sort_values(by=[f\"s={sort_step}\"], ascending=False)\n",
    "    pi_df = pi_df[0:top + 1]\n",
    "    \n",
    "    fi.write_fi_tex(\n",
    "        pi_df, \n",
    "        os.path.join(out_paths[\"features2\"], f\"impurity_imp_{model.name}_{level}.tex\")\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                               s=2       s=3       s=4       s=5       s=6  \\\n",
      "feature                                                                      \n",
      "wdi_vc_btl_deth           0.599052  0.514813  0.561183  0.562391  0.496475   \n",
      "reign_prev_conflict       0.245835  0.236472  0.228877  0.233958  0.199716   \n",
      "wdi_sm_pop_refg_or        0.154551  0.216374  0.264844  0.237946  0.219555   \n",
      "wdi_eg_use_pcap_kg_oe     0.154314  0.189194  0.201533  0.152955  0.132044   \n",
      "wdi_sm_pop_totl_zs        0.097629  0.138119  0.133283  0.134504  0.115202   \n",
      "ged_count_sb              0.176722  0.093262  0.174465 -0.005942  0.088451   \n",
      "wdi_bg_gsr_nfsv_gd_zs     0.062903  0.067768  0.110718  0.105336  0.047460   \n",
      "wdi_sp_pop_totl           0.038970  0.039520  0.046241  0.046715  0.043488   \n",
      "reign_loss                0.025009  0.035238  0.038774  0.055270  0.045674   \n",
      "wdi_nv_srv_totl_zs        0.018672  0.033444  0.046084  0.041778  0.025236   \n",
      "wdi_st_int_rcpt_xp_zs     0.033498  0.028210  0.022700  0.040039  0.012364   \n",
      "ged_count_os              0.040196  0.026717  0.019570  0.031898  0.019799   \n",
      "vdem_v2xpe_exlpol         0.031997  0.016095  0.023286  0.048773  0.035014   \n",
      "reign_lastelection        0.007996  0.015188  0.019321  0.007008  0.009469   \n",
      "wdi_dt_oda_odat_pc_zs     0.003308  0.015143  0.034296  0.027532  0.018660   \n",
      "ged_best_os               0.039080  0.014697  0.023772  0.011296  0.011706   \n",
      "reign_irregular           0.022939  0.013824  0.018034  0.019077  0.002460   \n",
      "wdi_ag_lnd_frst_k2        0.016190  0.011206  0.011451  0.018466  0.017534   \n",
      "wdi_sl_tlf_totl_fe_zs     0.013350  0.010599  0.013113  0.008097  0.014243   \n",
      "reign_precip              0.005105  0.010513  0.012621  0.005194  0.000359   \n",
      "reign_couprisk            0.013035  0.009642  0.002196 -0.001900  0.002003   \n",
      "vdem_v2x_genpp            0.016541  0.009348  0.013093  0.019444  0.018680   \n",
      "vdem_v2x_clpol            0.006483  0.009321  0.017033  0.012826  0.002380   \n",
      "reign_tenure_months       0.002142  0.008627  0.026638  0.018635  0.029238   \n",
      "ged_best_ns              -0.001122  0.004148  0.003381  0.004266  0.028375   \n",
      "reign_pctile_risk         0.012582  0.003387 -0.000258 -0.007612  0.006380   \n",
      "reign_gov_dominant_party  0.005581  0.001905  0.005744  0.001435  0.003660   \n",
      "vdem_v2x_accountability   0.012131  0.000938  0.010862  0.005208  0.006890   \n",
      "reign_age                 0.000914 -0.000707  0.002437 -0.000351  0.007953   \n",
      "ged_count_ns              0.005083 -0.001231  0.001180  0.002971  0.007146   \n",
      "views                     0.008482 -0.001678  0.000656  0.007034  0.003474   \n",
      "ged_best_sb              -0.645739 -0.706826 -0.566393 -0.651182 -0.608210   \n",
      "\n",
      "                               s=7  \n",
      "feature                             \n",
      "wdi_vc_btl_deth           0.432715  \n",
      "reign_prev_conflict       0.250681  \n",
      "wdi_sm_pop_refg_or        0.265968  \n",
      "wdi_eg_use_pcap_kg_oe     0.141667  \n",
      "wdi_sm_pop_totl_zs        0.159369  \n",
      "ged_count_sb              0.119762  \n",
      "wdi_bg_gsr_nfsv_gd_zs     0.167353  \n",
      "wdi_sp_pop_totl           0.052666  \n",
      "reign_loss                0.040124  \n",
      "wdi_nv_srv_totl_zs        0.038704  \n",
      "wdi_st_int_rcpt_xp_zs     0.025442  \n",
      "ged_count_os              0.005185  \n",
      "vdem_v2xpe_exlpol         0.011348  \n",
      "reign_lastelection        0.016290  \n",
      "wdi_dt_oda_odat_pc_zs     0.025431  \n",
      "ged_best_os               0.002926  \n",
      "reign_irregular           0.017175  \n",
      "wdi_ag_lnd_frst_k2        0.018286  \n",
      "wdi_sl_tlf_totl_fe_zs     0.020101  \n",
      "reign_precip              0.006857  \n",
      "reign_couprisk            0.002372  \n",
      "vdem_v2x_genpp            0.025628  \n",
      "vdem_v2x_clpol           -0.000660  \n",
      "reign_tenure_months       0.040006  \n",
      "ged_best_ns               0.010054  \n",
      "reign_pctile_risk        -0.001867  \n",
      "reign_gov_dominant_party  0.001016  \n",
      "vdem_v2x_accountability   0.010387  \n",
      "reign_age                 0.005238  \n",
      "ged_count_ns              0.004729  \n",
      "views                    -0.006550  \n",
      "ged_best_sb              -0.606874  \n"
     ]
    }
   ],
   "source": [
    "print(pi_df)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
